MCRI & UniMelb, UGent, MCRI & UniMelb
The average treatment effect (ATE) is a central causal estimand in clinical and epidemiological research
A substantial portion of biostatistics is devoted to analysing observational data.
Existing singly robust causal inference methods face practical challenges: they are sensitive to model misspecification and can be unstable under near-positivity violations
Machine learning (ML) is increasingly used in these analyses
But our estimands are causal effects, such as the ATE
Naive “plug-in ML” estimation of the ATE can be biased, converge slowly, and yield invalid confidence intervals
Observed data: for each observation we record \(W = (Y, X) = (Y, D, Z)\), where \(Y\) is the outcome, \(D \in \{0,1\}\) the treatment, and \(Z\) the covariates; write \(X = (D, Z)\)
Define the outcome regression function \(\gamma\) by \[ \gamma(d, z) = \mathbb{E}[Y \mid D = d, Z = z], \] and let \(\gamma_0\) denote the true regression function.
\[ \theta = E\bigl[ m(W; \gamma) \;+\; \alpha(X)\,\bigl(Y - \gamma(X)\bigr) \bigr]. \]
The function \(\alpha\) is the Riesz representer of the linear functional \(\gamma \mapsto E[m(W; \gamma)]\); for the ATE, \(m(W; \gamma) = \gamma(1, Z) - \gamma(0, Z)\).
We focus on problems where there exists a square-integrable random variable \(\alpha(X)\) such that: \[ E\bigl[m(W; \gamma)\bigr] \;=\; E\bigl[\alpha(X)\,\gamma(X)\bigr], \quad \text{for all } \gamma \text{ with } E\bigl[\gamma(X)^2\bigr] < \infty. \]
By the Riesz representation theorem, such an \(\alpha(X)\) exists if and only if \(E[m(W; \gamma)]\) is a continuous linear functional of \(\gamma\).
Traditional approach: derive the explicit form of the Riesz representer by hand
For the ATE, the Riesz representer is
\[ \alpha_0(D, Z) = \frac{D}{\pi_0(Z)} - \frac{1 - D}{1 - \pi_0(Z)}, \]
where \(\pi_0(Z) = P(D = 1 \mid Z)\) is the propensity score.
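A quick Monte Carlo check confirms the representation \(E[m(W;\gamma)] = E[\alpha(X)\gamma(X)]\) for this \(\alpha\); the propensity and regression functions below are arbitrary illustrations, not the ones used later.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 500_000

# Arbitrary illustrative propensity and regression (assumptions for this check).
z = rng.normal(size=n)
pi = 1.0 / (1.0 + np.exp(-0.5 * z))       # propensity score pi_0(Z)
d = rng.binomial(1, pi)                    # D ~ Bernoulli(pi_0(Z))

def gamma(d_val, z):
    # Some square-integrable outcome regression gamma(D, Z).
    return d_val * (1.0 + z) + np.sin(z)

# Left side: E[m(W; gamma)] = E[gamma(1, Z) - gamma(0, Z)]
lhs = np.mean(gamma(1, z) - gamma(0, z))

# Right side: E[alpha(X) gamma(X)] with alpha = D/pi - (1-D)/(1-pi)
alpha = d / pi - (1 - d) / (1 - pi)
rhs = np.mean(alpha * gamma(d, z))
# lhs and rhs agree up to Monte Carlo error
```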
Can we instead estimate the Riesz representer directly?
i.e., estimate \(\widehat{\frac{D}{\pi_0(Z)}}\) as a whole, rather than plugging an estimated propensity score into \(\frac{D}{\widehat{\pi_0(Z)}}\)
the Riesz representer is the minimiser of a loss that can be evaluated without knowing \(\alpha_0\): \[ \begin{aligned} \alpha_{0} &= \operatorname{argmin}_{\alpha} \,E\Bigl[\bigl(\alpha(X) - \alpha_{0}(X)\bigr)^2\Bigr] \\[6pt] &= \operatorname{argmin}_{\alpha} \,E\Bigl[ \alpha(X)^2 \;-\; 2\,\alpha_{0}(X)\,\alpha(X) \;+\; \alpha_{0}(X)^2 \Bigr] \\[6pt] &= \operatorname{argmin}_{\alpha} \,E\Bigl[ \alpha(X)^2 \;-\; 2\,m\bigl(W; \alpha\bigr)\Bigr] \\[6pt] &= \operatorname{argmin}_{\alpha} \,E\Bigl[ \alpha\bigl(D, Z\bigr)^2 - 2\bigl(\alpha(1, Z) - \alpha(0, Z)\bigr) \Bigr], \end{aligned} \] where the third line drops the constant \(E[\alpha_0(X)^2]\) and uses the Riesz representation \(E[\alpha_0(X)\alpha(X)] = E[m(W;\alpha)]\), and the last line specialises to the ATE.
Estimating \(\alpha\) by minimising this loss, and using it for debiasing, is called automatic debiased machine learning (autoDML)
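With a linear-in-features class \(\alpha(X) = \beta^\top \psi(X)\), the sample version of this loss is quadratic in \(\beta\) and has the closed-form minimiser \(\hat\beta = \hat G^{-1}\hat h\), with \(\hat G = \frac{1}{n}\sum_i \psi(X_i)\psi(X_i)^\top\) and \(\hat h = \frac{1}{n}\sum_i \bigl(\psi(1,Z_i) - \psi(0,Z_i)\bigr)\). A minimal sketch (the feature map and simulated data are illustrative assumptions):

```python
import numpy as np

def psi(d, z):
    # Feature map psi(D, Z): separate intercept and slope per treatment arm.
    d = np.asarray(d, dtype=float).reshape(-1, 1)
    return np.hstack([d, d * z, 1 - d, (1 - d) * z])

def fit_riesz_linear(d, z):
    # Closed-form minimiser of the empirical Riesz loss over alpha = psi @ beta.
    P = psi(d, z)
    G = P.T @ P / len(d)                                                    # G_hat
    h = (psi(np.ones(len(d)), z) - psi(np.zeros(len(d)), z)).mean(axis=0)   # h_hat
    return np.linalg.solve(G, h)

# Illustrative data: a simple logistic propensity (an assumption, not from the poster).
rng = np.random.default_rng(1)
n = 50_000
z = rng.normal(size=(n, 1))
pi = 1.0 / (1.0 + np.exp(-z[:, 0]))
d = rng.binomial(1, pi)

beta = fit_riesz_linear(d, z)
alpha_hat = psi(d, z) @ beta   # approximates D/pi - (1-D)/(1-pi), with no propensity model fit
```

By the first-order condition \(\hat G\hat\beta = \hat h\), the fitted \(\hat\alpha\) satisfies the empirical representation \(\frac{1}{n}\sum_i \hat\alpha(X_i)\gamma(X_i) = \frac{1}{n}\sum_i \bigl(\gamma(1,Z_i)-\gamma(0,Z_i)\bigr)\) exactly for every \(\gamma\) in the span of \(\psi\).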
Partition the set of data indices \(1,\ldots,n\) into \(L\) disjoint subsets of about equal size \(\{I_\ell\}_{\ell=1}^L\).
For each data fold \(\ell = 1,\ldots,L\):
Estimate \(\hat{\gamma}_\ell \in \mathcal{G}_n\) as a nonparametric regression of \(Y\) on \(X\), using observations not in \(I_\ell\).
Estimate the debiasing function \(\hat{\alpha}_\ell\) using observations not in \(I_\ell\), by minimizing a sample version of the loss (with regulariser \(\Lambda_r\)): \[ \hat{\alpha}_\ell = \arg\min_{\alpha \in \mathcal{A}_n} \Biggl\{ \sum_{i \notin I_\ell} \Bigl[ -\,2\,m(W_i;\alpha) + \alpha(X_i)^2 \Bigr] + \Lambda_r(\alpha) \Biggr\}, \]
Estimate the parameter of interest by plugging the cross-fitted \(\hat{\gamma}_\ell\) and \(\hat{\alpha}_\ell\) into the debiased moment function: \[ \hat{\theta} = \frac{1}{n}\sum_{\ell=1}^L \sum_{i \in I_\ell} \Bigl[ m\bigl(W_i;\hat{\gamma}_\ell\bigr) + \hat{\alpha}_\ell\bigl(X_i\bigr)\bigl\{Y_i - \hat{\gamma}_\ell\bigl(X_i\bigr)\bigr\} \Bigr]. \]
Estimate the standard error of \(\hat{\theta}\) as \(\sqrt{\hat{V}/n}\), where \[ \hat{V} = \frac{1}{n}\sum_{\ell=1}^L \sum_{i \in I_\ell} \Bigl[ m\bigl(W_i;\hat{\gamma}_\ell\bigr) + \hat{\alpha}_\ell\bigl(X_i\bigr)\bigl\{Y_i - \hat{\gamma}_\ell\bigl(X_i\bigr)\bigr\} - \hat{\theta} \Bigr]^2. \]
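The four steps above can be sketched compactly. Here both \(\hat\gamma_\ell\) (OLS) and \(\hat\alpha_\ell\) (closed-form linear Riesz regression, unpenalised, so \(\Lambda_r \equiv 0\)) are simple linear stand-ins for the nonparametric learners, and the toy data-generating process is an assumption:

```python
import numpy as np

def features(d, z):
    # Per-arm intercepts and slopes; used for both gamma_hat and alpha_hat.
    d = np.asarray(d, dtype=float).reshape(-1, 1)
    return np.hstack([d, d * z, 1 - d, (1 - d) * z])

def auto_dml_ate(y, d, z, n_folds=5, seed=0):
    n = len(y)
    folds = np.random.default_rng(seed).integers(0, n_folds, size=n)
    scores = np.zeros(n)                      # per-observation moment contributions
    ones, zeros = np.ones(n), np.zeros(n)
    for l in range(n_folds):
        tr, te = folds != l, folds == l
        P_tr = features(d[tr], z[tr])
        # Step 1: outcome regression gamma_hat, fit off-fold (OLS stand-in).
        g_coef, *_ = np.linalg.lstsq(P_tr, y[tr], rcond=None)
        # Step 2: Riesz regression alpha_hat, fit off-fold (closed form, no penalty).
        G = P_tr.T @ P_tr / tr.sum()
        h = (features(ones[tr], z[tr]) - features(zeros[tr], z[tr])).mean(axis=0)
        a_coef = np.linalg.solve(G, h)
        # Step 3: evaluate the debiased moment on the held-out fold.
        P_te = features(d[te], z[te])
        m_hat = (features(ones[te], z[te]) - features(zeros[te], z[te])) @ g_coef
        scores[te] = m_hat + (P_te @ a_coef) * (y[te] - P_te @ g_coef)
    # Step 4: point estimate and standard error.
    return scores.mean(), np.sqrt(scores.var() / n)

# Toy check on an assumed DGP with true ATE = 2.
rng = np.random.default_rng(7)
n = 4000
z = rng.normal(size=(n, 1))
d = rng.binomial(1, 1.0 / (1.0 + np.exp(-z[:, 0])))
y = 2.0 * d + z[:, 0] + rng.normal(size=n)
theta_hat, se_hat = auto_dml_ate(y, d, z)
```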
Since \(D \in \{0,1\}\), we can decompose the Riesz representer as
\[
\alpha(D,Z)
\;=\;
D \,\alpha(1,Z)
\;+\;
(1-D)\,\alpha(0,Z).
\]
This suggests a two-headed architecture: a multilayer perceptron (MLP) with shared hidden layers and two output heads, one for \(\alpha(1, Z)\) and one for \(\alpha(0, Z)\)
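A minimal numpy sketch of such a network's forward pass (the layer sizes and random weights are placeholders; in practice the network would be trained by minimising the Riesz loss above):

```python
import numpy as np

rng = np.random.default_rng(0)

def relu(x):
    return np.maximum(x, 0.0)

# Shared trunk (two hidden layers for brevity) and two linear output heads.
p, hdim = 3, 16
W1 = rng.normal(scale=0.3, size=(p, hdim))
W2 = rng.normal(scale=0.3, size=(hdim, hdim))
w1_head = rng.normal(scale=0.3, size=hdim)   # head for alpha(1, Z)
w0_head = rng.normal(scale=0.3, size=hdim)   # head for alpha(0, Z)

def alpha_heads(z):
    # Shared representation, then one scalar output per head.
    shared = relu(relu(z @ W1) @ W2)
    return shared @ w1_head, shared @ w0_head

def alpha(d, z):
    # alpha(D, Z) = D * alpha(1, Z) + (1 - D) * alpha(0, Z)
    a1, a0 = alpha_heads(z)
    return d * a1 + (1 - d) * a0

# One forward pass yields every term in the empirical Riesz loss:
# alpha(X_i)^2 - 2 * (alpha(1, Z_i) - alpha(0, Z_i)).
z = rng.normal(size=(5, p))
d = np.array([1.0, 0.0, 1.0, 1.0, 0.0])
a1, a0 = alpha_heads(z)
loss_terms = alpha(d, z) ** 2 - 2.0 * (a1 - a0)
```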
We generate 1,000 datasets, each with 2,000 observations.
Covariates: \(Z_1, Z_2, Z_3 \sim N(0,1).\)
True propensity: \(\pi(Z) = \operatorname{expit}\bigl(2Z_1 - Z_2 + 0.5 Z_3\bigr).\)
Treatment variable: \(D \sim \mathrm{Bernoulli}\bigl(\pi(Z)\bigr).\)
Outcome model with known ATE of 2: \(Y = 2D + Z_1 - Z_2 + 0.5 Z_3 + \varepsilon,\qquad\varepsilon\sim N(0,1).\)
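The data-generating process can be sketched directly (the function name is illustrative):

```python
import numpy as np

def simulate(n=2000, seed=0):
    """One dataset from the simulation design above; the true ATE is 2."""
    rng = np.random.default_rng(seed)
    Z = rng.normal(size=(n, 3))                                           # Z1, Z2, Z3 ~ N(0,1)
    pi = 1.0 / (1.0 + np.exp(-(2 * Z[:, 0] - Z[:, 1] + 0.5 * Z[:, 2])))   # expit(...)
    D = rng.binomial(1, pi)                                               # treatment
    Y = 2 * D + Z[:, 0] - Z[:, 1] + 0.5 * Z[:, 2] + rng.normal(size=n)    # outcome
    return Y, D, Z

Y, D, Z = simulate()
```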
We compare autoDML, DML (AIPW), and TMLE.
For DML and TMLE:
For autoDML:
The outcome regression is correctly specified.
The Riesz representer is estimated using a two-headed MLP, with four shared layers and two separate output heads for \(\alpha(1, Z)\) and \(\alpha(0, Z)\).
| Method | ATE Mean | SD | MSE | Coverage |
|---|---|---|---|---|
| TMLE | 1.997 | 0.088 | 0.008 | 90% |
| AIPW | 1.998 | 0.142 | 0.020 | 97% |
| autoDML | 1.997 | 0.068 | 0.005 | 93% |
autoDML automatically learns the Riesz representer. This greatly simplifies implementation for complex causal parameters.
autoDML is more stable than the standard AIPW estimator
Simple simulations show that learned Riesz weights reduce variance relative to AIPW and TMLE with plug-in models.
Future work will evaluate autoDML in more realistic simulation studies that are directly motivated by cohort studies.